import pandas as pd
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

from torch.utils.data import DataLoader, Dataset
import torchvision as tv, torchvision.transforms as tr

# Current user's home directory; data files below are resolved relative to it.
home_dir = os.path.expanduser('~')
# Human-annotator class probabilities (CIFAR-10H). `calibration` loads this
# array with np.load and takes its argmax as the "human" labels.
human_conf_path = os.path.join(
    home_dir,
    'data/cifar10/cifar10h-probs.npy',
)

def calibration(model, dload_real, save_dir, sigma=0., device='cuda', args=None):
    """Evaluate and report the calibration of ``model`` on ``dload_real``.

    Runs ``model`` over every ``(x, y)`` batch, then computes the Expected
    Calibration Error (via ``ECELoss``) twice: against the dataset labels and
    against "human" labels obtained as the argmax of the probabilities stored
    at ``human_conf_path``. Per-bin statistics are written to
    ``<args.output_path>/calibration.csv``, summary metrics are logged to
    wandb (without committing the step) and printed, and reliability
    diagrams are saved under ``save_dir`` (the human-label diagram only when
    ``args.dataset == 'cifar10'``).

    Args:
        model: callable mapping an input batch to class logits.
        dload_real: iterable of ``(x, y)`` batches; ``y`` holds integer labels.
        save_dir: directory the reliability-diagram PNGs are written to.
        sigma: unused; kept for interface compatibility.
        device: device the inputs and all ECE computations are placed on.
        args: namespace providing ``output_path`` and ``dataset``.

    Returns:
        ``(ece, ece_h, hcorrect, df)``: ECE w.r.t. the dataset labels, ECE
        w.r.t. the human labels, the fraction of predictions matching the
        human labels, and the per-bin DataFrame written to CSV.
    """
    human_conf = np.load(human_conf_path)
    # NOTE: ECELoss ignores n_bins and uses its own fixed bin edges; the CSV
    # index below is derived from those edges so the two cannot drift apart.
    ece_com = ECELoss(20)

    labels, pred, real_scores, logits_l = [], [], [], []
    for x, y in dload_real:
        logits = model(x.to(device)).detach()
        logits_l.append(logits)
        # Same top-1 rule as ECELoss.forward (torch.max over the softmax).
        top1_scores, top1_preds = torch.max(F.softmax(logits, dim=1), dim=1)
        labels.append(y.cpu().numpy())
        real_scores.append(top1_scores.cpu().numpy())
        pred.append(top1_preds.cpu().numpy())

    logits_l = torch.cat(logits_l).to(device)
    labels = np.concatenate(labels, axis=0)
    real_scores = np.concatenate(real_scores)
    pred = np.concatenate(pred)
    hlabels = human_conf.argmax(-1)

    # ECE w.r.t. the dataset labels.
    ece = ece_com(logits_l, torch.as_tensor(labels, dtype=torch.long, device=device)).item()
    prob_bin, avg_conf, ece_bin, accs, diff = ece_com.get_bin()
    ece_com.init()

    # ECE w.r.t. the human labels. Uses `device` rather than a hard-coded CUDA
    # device so the evaluation also runs on CPU.
    ece_h = ece_com(logits_l, torch.as_tensor(hlabels, dtype=torch.long, device=device)).item()
    _, h_avg_conf, h_ece_bin, h_accs, h_diff = ece_com.get_bin()
    ece_com.init()

    df = pd.DataFrame([accs, h_accs, h_avg_conf, diff, h_diff, prob_bin, ece_bin, h_ece_bin]).T
    df.columns = ['acc(M)', 'acc(H)', 'Confidence', 'DIFF(M)', 'DIFF(H)', 'Prop bin', 'ECE(M)', 'ECE(H)']
    # Row labels like '0.95~0.96'; rounding to 2 decimals strips float32 noise
    # from the default edges (e.g. 0.35 is stored as 0.34999999...).
    df.index = [f'{round(lo.item(), 2)}~{round(hi.item(), 2)}'
                for lo, hi in zip(ece_com.bin_lowers, ece_com.bin_uppers)]
    df.to_csv(os.path.join(args.output_path, 'calibration.csv'))

    correct = (pred == labels).sum() / len(labels)
    hcorrect = (pred == hlabels).sum() / len(labels)
    wandb.log({
        "calibration/ece": ece,
        "calibration/acc": correct,
        "calibration/acc_h": hcorrect,
    }, commit=False)
    print(f"ECE: {ece}, ECE(H): {ece_h}, Acc: {correct}, HAcc: {hcorrect}")
    reliability_diagrams(
        list(pred), list(labels), list(real_scores),
        bin_size=0.05, save_dir=os.path.join(save_dir, 'ece_M-conf-M-acc.png'),
        title="Accuracy: %.2f%%" % (100.0 * correct))
    if args.dataset == 'cifar10':
        reliability_diagrams(
            list(pred), list(hlabels), list(real_scores),
            bin_size=0.05, save_dir=os.path.join(save_dir, 'ece_M-conf-H-acc.png'),
            title="Accuracy: %.2f%%" % (100.0 * hcorrect))
    return ece, ece_h, hcorrect, df

def expected_calibration_error(predictions, truths, confidences, bin_size=0.1, title='demo'):
    """Return the Expected Calibration Error of a set of predictions.

    Bins are half-open intervals ``(ub - bin_size, ub]`` whose upper bounds
    ``ub`` come from ``np.arange(bin_size, 1 + bin_size, bin_size)``. The ECE
    is ``sum_b |mean_conf_b - acc_b| * frac_b``, where ``frac_b`` is the share
    of all samples that fall in bin ``b``; empty bins contribute 0.

    Args:
        predictions: sequence of predicted class indices.
        truths: sequence of reference class indices aligned with predictions.
        confidences: sequence of confidence scores aligned with predictions.
        bin_size: width of each confidence bin.
        title: unused; kept for interface compatibility.
    """
    upper_bounds = np.arange(bin_size, 1 + bin_size, bin_size)
    ece = 0
    for upper in upper_bounds:
        acc, frac, avg_conf = compute_accuracy(
            upper - bin_size, upper, confidences, predictions, truths)
        ece += abs(avg_conf - acc) * frac
    return ece


def reliability_diagrams(
        predictions, truths, confidences,
        bin_size=0.1, title='demo', save_dir='./calibration'):
    """Save two reliability diagrams and log the bar-chart one to wandb.

    Confidences are split into half-open bins ``(ub - bin_size, ub]`` with
    upper bounds from ``np.arange(bin_size, 1 + bin_size, bin_size)``. For
    each bin the empirical accuracy (``prediction == truth``) and mean
    confidence are computed; the ECE shown in both titles is the
    bin-mass-weighted mean of ``|mean confidence - accuracy|``.

    Two figures are written:
      * ``<save_dir without extension>2.png``: accuracy vs. mean confidence,
        drawn through the non-empty bins only;
      * ``save_dir``: per-bin accuracy bars whose opacity grows with the
        number of samples in the bin. This image is also logged to wandb
        under ``calibration/<file name without extension>``.

    Args:
        predictions: sequence of predicted class indices.
        truths: sequence of reference class indices aligned with predictions.
        confidences: sequence of confidence scores aligned with predictions.
        bin_size: width of each confidence bin.
        title: unused (the figure titles show the ECE); kept for interface
            compatibility.
        save_dir: output *file path* of the bar chart; '.png' is appended
            when it has no extension.
    """
    # matplotlib appends an extension to extension-less paths when saving,
    # which would make wandb.Image(save_dir) miss the file; normalise first.
    root, ext = os.path.splitext(save_dir)
    if not ext:
        save_dir = root + '.png'

    upper_bounds = np.arange(bin_size, 1 + bin_size, bin_size)

    # Per-bin accuracy, mean confidence and sample count, plus the ECE.
    accs, conf_x, bin_counts = [], [], []
    ece = 0
    for upper in upper_bounds:
        lower = upper - bin_size
        acc, perc_pred, avg_conf = compute_accuracy(lower, upper, confidences, predictions, truths)
        accs.append(acc)
        conf_x.append(avg_conf)
        bin_counts.append(calculate_count_bin(lower, upper, confidences))
        ece += abs(avg_conf - acc) * perc_pred

    # Figure 1: accuracy vs. mean confidence. Empty bins report conf = acc = 0,
    # so they are skipped; selection is deterministic (based on bin counts).
    shown = [i for i, count in enumerate(bin_counts) if count > 0]
    fig, ax = plt.subplots()
    ax.errorbar([0, 1], [0, 1], label="Perfect classifier calibration")
    ax.plot([conf_x[i] for i in shown], [accs[i] for i in shown],
            label="Accuracy", color="#33a1b9")
    ax.set_ylim([0, 1])
    ax.set_xlim([0, 1])
    plt.title(" ECE: %.2f%%" % (ece * 100))
    plt.ylabel('acc.')
    plt.xlabel('probability')
    plt.savefig(root + '2.png')
    plt.close(fig)

    # Figure 2: one bar per bin, centred on the bin for any bin_size.
    fig, ax = plt.subplots()
    ax.errorbar([0, 1], [0, 1], label="Perfect classifier calibration")
    bars = ax.bar(upper_bounds - bin_size / 2, accs, width=bin_size,
                  label="Accuracy", color="#33a1b9", edgecolor='gray', align='center')
    total = sum(bin_counts)
    for bar, count in zip(bars, bin_counts):
        # Opacity runs from 0.05 (empty bin) to 0.95 (all samples in one bin);
        # `total` is 0 only for empty input, which would otherwise divide by 0.
        bar.set_alpha(0.05 + (count / (total * 10 / 9) if total else 0))
    ax.set_ylim([0, 1])
    ax.set_xlim([0, 1])
    plt.title("ECE: %.2f%%" % (ece * 100), fontsize=20)
    plt.ylabel('acc.')
    plt.xlabel('probability')
    plt.savefig(save_dir)
    plt.close(fig)
    wandb.log({
        f"calibration/{os.path.splitext(os.path.basename(save_dir))[0]}": wandb.Image(save_dir)
    })


def calculate_count_bin(conf_thresh_lower, conf_thresh_upper, conf):
    """Count the confidences ``c`` with ``conf_thresh_lower < c <= conf_thresh_upper``."""
    return sum(1 for c in conf if c > conf_thresh_lower and c <= conf_thresh_upper)


def compute_accuracy(conf_thresh_lower, conf_thresh_upper, conf, pred, true):
    """Return ``(accuracy, fraction_of_data, mean_confidence)`` for one bin.

    A sample is in the bin when
    ``conf_thresh_lower < confidence <= conf_thresh_upper``. ``accuracy`` is
    the share of in-bin samples with ``pred == true``; ``fraction_of_data``
    is the in-bin count divided by ``len(conf)``. An empty bin yields
    ``(0, 0, 0)``.
    """
    members = [
        (p, t, c) for p, t, c in zip(pred, true, conf)
        if c > conf_thresh_lower and c <= conf_thresh_upper
    ]
    if not members:
        return 0, 0, 0
    n_members = len(members)
    n_correct = sum(1 for p, t, _ in members if p == t)
    mean_conf = sum([c for _, _, c in members]) / n_members
    return float(n_correct) / n_members, float(n_members) / len(conf), mean_conf


def compute_accuracy2(conf_thresh_lower, conf_thresh_upper, conf, pred, true):
    """Class-balanced variant of ``compute_accuracy`` for one confidence bin.

    Returns ``(accuracy, fraction_of_data, mean_confidence)``. A sample is in
    the bin when ``conf_thresh_lower < confidence <= conf_thresh_upper``.
    ``accuracy`` is the mean, over classes ``0 .. max(true)``, of the
    per-class recall inside the bin; classes with no in-bin samples count as
    0. ``fraction_of_data`` is the in-bin count divided by ``len(conf)``.
    An empty bin (including empty input) yields ``(0, 0, 0)``.
    """
    filtered_tuples = [x for x in zip(pred, true, conf)
                       if x[2] > conf_thresh_lower and x[2] <= conf_thresh_upper]
    if not filtered_tuples:
        return 0, 0, 0

    # Labels are matched against range(num_classes), i.e. treated as 0-based,
    # so there are max(true) + 1 classes.
    num_classes = max(true) + 1
    per_class_acc = []
    for i in range(num_classes):
        category = sum(1 for _, t, _ in filtered_tuples if t == i)
        correct = sum(1 for p, t, _ in filtered_tuples if p == i and p == t)
        per_class_acc.append(float(correct) / category if category else 0)

    avg_conf = sum([x[2] for x in filtered_tuples]) / len(filtered_tuples)
    perc_of_data = float(len(filtered_tuples)) / len(conf)
    accuracy = sum(per_class_acc) / num_classes
    return accuracy, perc_of_data, avg_conf


class ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).
    The input to this loss is the logits of a model, NOT the softmax scores.

    Confidences (max softmax probability) are split into half-open bins
    ``(lower, upper]``. By default the bins are NOT equal-width: 0.05-wide
    from 0 to 0.95, then 0.01-wide from 0.95 to 1.0 (24 bins in total).
    In each bin, we compute the confidence gap:
    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |
    We then return a weighted average of the gaps, based on the number
    of samples in each bin.

    Every ``forward`` call also appends one entry per bin to the per-bin
    statistic lists returned by ``get_bin``; call ``init`` to reset them.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    # Default bin edges: finer resolution near confidence 1.
    _DEFAULT_BOUNDARIES = (
        0,
        0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4,
        0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9,
        0.95, 0.96, 0.97, 0.98, 0.99, 1.0,
    )

    def __init__(self, n_bins=15, bin_boundaries=None):
        """
        n_bins (int): ignored; kept for backward compatibility. The bin
            layout comes from ``bin_boundaries``.
        bin_boundaries (sequence of float, optional): increasing bin edges,
            normally starting at 0 and ending at 1; confidences outside
            ``(first, last]`` are not counted in any bin. Defaults to the
            24-bin layout described in the class docstring.
        """
        super(ECELoss, self).__init__()
        if bin_boundaries is None:
            bin_boundaries = self._DEFAULT_BOUNDARIES
        bin_boundaries = torch.as_tensor(bin_boundaries, dtype=torch.float32)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]
        self.init()

    def update(self, prob=0, conf=0, acc=0, ece=0, diff=0):
        """Append one bin's statistics; the all-zero defaults mark an empty bin."""
        self.bin_p.append(prob)
        self.avg_conf.append(conf)
        self.accs.append(acc)
        self.ece.append(ece)
        self.diff.append(diff)

    def get_bin(self):
        """Return the per-bin lists ``(bin_p, avg_conf, ece, accs, diff)``.

        Note the order: the ECE contributions come before the accuracies.
        """
        return self.bin_p, self.avg_conf, self.ece, self.accs, self.diff

    def init(self):
        """Reset the per-bin statistics.

        The lists are rebound to fresh objects (not cleared in place), so lists
        previously returned by ``get_bin`` keep their contents.
        """
        self.bin_p = []
        self.avg_conf = []
        self.accs = []
        self.ece = []
        self.diff = []

    def forward(self, logits, labels):
        """Return the ECE as a 1-element tensor on ``logits.device``.

        ``logits`` is indexed as (batch, classes) via the softmax over dim 1;
        ``labels`` must be comparable element-wise with the argmax predictions.
        Records one ``update`` per bin as a side effect.
        """
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) & confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece_bin = torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
                ece += ece_bin
                diff = avg_confidence_in_bin - accuracy_in_bin
                self.update(
                    prop_in_bin.item(), avg_confidence_in_bin.item(),
                    accuracy_in_bin.item(), ece_bin.item(), diff.item())
            else:
                self.update()
        return ece